{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pynvml import *\n",
    "\n",
    "# Abort early when GPU 0 does not have enough free memory for training.\n",
    "nvmlInit()\n",
    "try:\n",
    "    # NVML reports `free` in bytes; convert it to MiB.\n",
    "    vram = nvmlDeviceGetMemoryInfo(nvmlDeviceGetHandleByIndex(0)).free/1024.**2\n",
    "finally:\n",
    "    # Always release NVML, including when the query or the check below fails.\n",
    "    nvmlShutdown()\n",
    "print('GPU0 Memory: %dMB' % vram)\n",
    "if vram < 8000:\n",
    "    raise Exception('GPU Memory too low')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 导入必要的库\n",
    "\n",
    "我们需要导入一个叫 [captcha](https://github.com/lepture/captcha) 的库来生成验证码。\n",
    "\n",
    "我们生成验证码的字符集由数字、运算符（`+`、`-`、`*`）、括号和空格组成。\n",
    "\n",
    "`/usr/local/lib/python2.7/dist-packages/captcha/image.py`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from captcha.image import ImageCaptcha\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import random\n",
    "import os\n",
    "\n",
    "%matplotlib inline\n",
    "%config InlineBackend.figure_format = 'retina'\n",
    "\n",
    "import string\n",
    "from IPython.display import display\n",
    "\n",
    "# Alphabet of the captcha: digits, arithmetic operators, parentheses and a space.\n",
    "digits = string.digits\n",
    "operators = '+-*'\n",
    "characters = digits + operators + '() '\n",
    "print(characters)\n",
    "\n",
    "# Image size, maximum label length and number of output classes.\n",
    "# The extra class on top of the alphabet is presumably the CTC blank label.\n",
    "width, height, n_len = 180, 60, 7\n",
    "n_class = len(characters) + 1\n",
    "print(n_class)\n",
    "\n",
    "# Select fonts by their real extension: the former substring test ('.tt' in name)\n",
    "# also matched non-font files such as 'notes.tt.txt'. Sorting gives a stable font\n",
    "# order, because os.listdir returns entries in arbitrary order.\n",
    "font_files = sorted(f for f in os.listdir('fonts')\n",
    "                    if os.path.splitext(f)[1].lower() in ('.ttf', '.ttc'))\n",
    "generator = ImageCaptcha(width=width, height=height, font_sizes=range(35, 56),\n",
    "                         fonts=['fonts/%s' % f for f in font_files])\n",
    "display(generator.generate_image('(1+2)*3'))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 生成混合运算字符串"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate():\n",
    "    \"\"\"Return a random three-operand expression such as '1+2*3', '(1+2)*3' or '1+(2*3)'.\"\"\"\n",
    "    # 0: no parentheses, 1: around the first two operands, 2: around the last two.\n",
    "    paren_mode = random.randint(0, 2)\n",
    "\n",
    "    # Draw the pieces in the order they appear in the expression.\n",
    "    first = random.choice(digits)\n",
    "    op_a = random.choice(operators)\n",
    "    second = random.choice(digits)\n",
    "    op_b = random.choice(operators)\n",
    "    third = random.choice(digits)\n",
    "\n",
    "    if paren_mode == 1:\n",
    "        return '(' + first + op_a + second + ')' + op_b + third\n",
    "    if paren_mode == 2:\n",
    "        return first + op_a + '(' + second + op_b + third + ')'\n",
    "    return first + op_a + second + op_b + third\n",
    "\n",
    "generate()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 定义 CTC Loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from keras import backend as K\n",
    "\n",
    "def ctc_lambda_func(args):\n",
    "    \"\"\"Compute the CTC loss inside a Lambda layer.\n",
    "\n",
    "    `args` is the sequence (y_pred, labels, input_length, label_length).\n",
    "    \"\"\"\n",
    "    predictions, targets, pred_lengths, target_lengths = args\n",
    "    # The first two time steps of the prediction are dropped before the loss.\n",
    "    trimmed = predictions[:, 2:, :]\n",
    "    return K.ctc_batch_cost(targets, trimmed, pred_lengths, target_lengths)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 定义网络结构"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from keras.layers import *\n",
    "from keras.models import *\n",
    "from make_parallel import make_parallel\n",
    "rnn_size = 128\n",
    "\n",
    "# Convolutional feature extractor: three stages of (Conv-BN-ReLU) x 2 followed by\n",
    "# 2x2 max-pooling, with 32, 64 and 128 filters. Input layout is (width, height, 3).\n",
    "input_tensor = Input((width, height, 3))\n",
    "x = input_tensor\n",
    "for i in range(3):\n",
    "    x = Conv2D(32*2**i, (3, 3), kernel_initializer='he_normal')(x)\n",
    "    x = BatchNormalization()(x)\n",
    "    x = Activation('relu')(x)\n",
    "    x = Conv2D(32*2**i, (3, 3), kernel_initializer='he_normal')(x)\n",
    "    x = BatchNormalization()(x)\n",
    "    x = Activation('relu')(x)\n",
    "    x = MaxPooling2D(pool_size=(2, 2))(x)\n",
    "\n",
    "# Merge the height and channel axes so that the width axis becomes the sequence axis.\n",
    "# conv_shape is also read by the data generator to compute the CTC input length.\n",
    "conv_shape = x.get_shape()\n",
    "x = Reshape(target_shape=(int(conv_shape[1]), int(conv_shape[2]*conv_shape[3])))(x)\n",
    "\n",
    "x = Dense(128, kernel_initializer='he_normal')(x)\n",
    "x = BatchNormalization()(x)\n",
    "x = Activation('relu')(x)\n",
    "\n",
    "# Two stacked pairs of forward/backward GRUs: the first pair is summed, the second concatenated.\n",
    "# NOTE(review): per the Keras docs a go_backwards=True layer returns the reversed sequence,\n",
    "# so the backward outputs may not be time-aligned with the forward ones - confirm.\n",
    "gru_1 = GRU(rnn_size, return_sequences=True, kernel_initializer='he_normal', name='gru1')(x)\n",
    "gru_1b = GRU(rnn_size, return_sequences=True, go_backwards=True, kernel_initializer='he_normal', \n",
    "             name='gru1_b')(x)\n",
    "gru1_merged = add([gru_1, gru_1b])\n",
    "\n",
    "gru_2 = GRU(rnn_size, return_sequences=True, kernel_initializer='he_normal', name='gru2')(gru1_merged)\n",
    "gru_2b = GRU(rnn_size, return_sequences=True, go_backwards=True, kernel_initializer='he_normal', \n",
    "             name='gru2_b')(gru1_merged)\n",
    "x = concatenate([gru_2, gru_2b])\n",
    "x = Dropout(0.25)(x)\n",
    "x = Dense(n_class, kernel_initializer='he_normal', activation='softmax')(x)\n",
    "# Use the Keras 2 keyword names (inputs/outputs), consistent with the training model below;\n",
    "# the Keras 1 spellings input/output are deprecated.\n",
    "base_model = Model(inputs=input_tensor, outputs=x)\n",
    "\n",
    "# Wrap the prediction model with the project helper make_parallel, here with 4 as its second argument.\n",
    "base_model2 = make_parallel(base_model, 4)\n",
    "\n",
    "# Training model: extra inputs carry the labels and sequence lengths, and the Lambda\n",
    "# layer named 'ctc' outputs the CTC loss directly.\n",
    "labels = Input(name='the_labels', shape=[n_len], dtype='float32')\n",
    "input_length = Input(name='input_length', shape=(1,), dtype='int64')\n",
    "label_length = Input(name='label_length', shape=(1,), dtype='int64')\n",
    "loss_out = Lambda(ctc_lambda_func, name='ctc')([base_model2.output, labels, input_length, label_length])\n",
    "\n",
    "model = Model(inputs=(input_tensor, labels, input_length, label_length), outputs=loss_out)\n",
    "# The network output already is the loss, so the Keras loss function just passes it through.\n",
    "model.compile(loss={'ctc': lambda y_true, y_pred: y_pred}, optimizer='adam')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 定义数据生成器"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def gen(batch_size=128):\n",
    "    \"\"\"Yield endless batches of synthetic captchas for the CTC training model.\n",
    "\n",
    "    Each item is ([X, y, input_length, label_length], dummy_target), where X is a\n",
    "    uint8 array of shape (batch_size, width, height, 3), y holds character indices\n",
    "    padded with -1 up to n_len, and dummy_target is an array of ones.\n",
    "    \"\"\"\n",
    "    # Sequence length passed to the CTC loss, which drops the first two time steps.\n",
    "    n_steps = int(conv_shape[1]-2)\n",
    "    # NOTE(review): if several worker processes run this generator they may all start\n",
    "    # from the same `random` state and produce identical batches - confirm.\n",
    "    while True:\n",
    "        # Allocate fresh arrays for every batch so that a batch already handed to a\n",
    "        # consumer is never overwritten while the next one is being generated.\n",
    "        X = np.zeros((batch_size, width, height, 3), dtype=np.uint8)\n",
    "        y = np.full((batch_size, n_len), -1, dtype=np.int32)\n",
    "        label_length = np.ones(batch_size)\n",
    "        for i in range(batch_size):\n",
    "            random_str = generate()\n",
    "            # Swap the first two image axes to match the (width, height, 3) model input.\n",
    "            X[i] = np.array(generator.generate_image(random_str)).transpose(1, 0, 2)\n",
    "            y[i, :len(random_str)] = [characters.find(ch) for ch in random_str]\n",
    "            label_length[i] = len(random_str)\n",
    "        yield [X, y, np.ones(batch_size)*n_steps, label_length], np.ones(batch_size)\n",
    "\n",
    "# Preview one generated sample. Distinct names avoid clashing with the real test set.\n",
    "[X_sample, y_sample, _, _], _ = next(gen(1))\n",
    "plt.imshow(X_sample[0].transpose(1, 0, 2))\n",
    "# Skip the -1 padding: characters[-1] would silently map it to the last character.\n",
    "plt.title(''.join([characters[x] for x in y_sample[0] if x > -1]))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 验证函数和回调函数"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm import tqdm\n",
    "import pandas as pd\n",
    "import cv2\n",
    "\n",
    "# Load the labelled contest images used as the validation / test set.\n",
    "df = pd.read_csv('../../image_contest_level_1/labels.txt', sep=' ', header=None)\n",
    "n_test = 100000\n",
    "X_test = np.zeros((n_test, width, height, 3), dtype=np.uint8)\n",
    "# Labels are padded with -1 up to n_len.\n",
    "y_test = np.full((n_test, n_len), -1, dtype=np.int32)\n",
    "label_length_test = np.zeros((n_test, 1), dtype=np.int32)\n",
    "\n",
    "for i in tqdm(range(n_test)):\n",
    "    path = '../../image_contest_level_1/%d.png'%i\n",
    "    img = cv2.imread(path)\n",
    "    if img is None:\n",
    "        # cv2.imread reports failure by returning None instead of raising.\n",
    "        raise IOError('Cannot read image: %s' % path)\n",
    "    # Reverse the channel order (OpenCV loads BGR) and swap the first two axes\n",
    "    # to match the (width, height, 3) model input.\n",
    "    X_test[i] = img[:,:,::-1].transpose(1, 0, 2)\n",
    "    random_str = df[0][i]\n",
    "    y_test[i,:len(random_str)] = [characters.find(x) for x in random_str]\n",
    "    label_length_test[i] = len(random_str)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def evaluate(model):\n",
    "    \"\"\"Return the fraction of test samples whose decoded sequence equals the label exactly.\n",
    "\n",
    "    `model` is the prediction model to run on the global X_test; the result is\n",
    "    compared against the global y_test (labels padded with -1 up to n_len).\n",
    "    \"\"\"\n",
    "    # Honour the `model` argument (it used to be ignored in favour of a global).\n",
    "    y_pred = model.predict(X_test, batch_size=1024)\n",
    "    # Drop the first two time steps, as the CTC loss does during training.\n",
    "    y_pred = y_pred[:, 2:, :]\n",
    "    n_samples, n_steps = y_pred.shape[:2]\n",
    "    decoded = K.ctc_decode(y_pred, input_length=np.ones(n_samples)*n_steps)[0][0]\n",
    "    out = K.get_value(decoded)[:, :n_len]\n",
    "    # The decoded matrix is only as wide as the longest decoded sequence; pad it with -1\n",
    "    # so it always has the same shape as y_test and the comparison is element-wise.\n",
    "    if out.shape[1] < n_len:\n",
    "        out = np.pad(out, ((0, 0), (0, n_len - out.shape[1])), mode='constant', constant_values=-1)\n",
    "    return (y_test == out).all(axis=-1).mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from keras.callbacks import *\n",
    "\n",
    "class Evaluate(Callback):\n",
    "    \"\"\"Keras callback that records the accuracy returned by evaluate() after each epoch.\"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        # Let the base Callback initialise its own attributes first.\n",
    "        super(Evaluate, self).__init__()\n",
    "        # Accuracy in percent, one entry per finished epoch.\n",
    "        self.accs = []\n",
    "\n",
    "    def on_epoch_end(self, epoch, logs=None):\n",
    "        # Pass base_model2 (the make_parallel wrapper around base_model), which is the\n",
    "        # model the predictions are run on.\n",
    "        acc = evaluate(base_model2)*100\n",
    "        self.accs.append(acc)\n",
    "        # Single-argument print calls behave the same on Python 2 and Python 3.\n",
    "        print('')\n",
    "        print('val_acc: %f%%'%acc)\n",
    "\n",
    "evaluator = Evaluate()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 训练"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "from keras.optimizers import *\n",
    "from keras.callbacks import *\n",
    "\n",
    "# Batch size per training step (1024 was the previously tried value).\n",
    "batch_size = 512\n",
    "# Recompile with Adam at learning rate 1e-4. The 'ctc' layer already outputs the loss,\n",
    "# so the Keras loss function just passes y_pred through.\n",
    "model.compile(loss={'ctc': lambda y_true, y_pred: y_pred}, optimizer=Adam(1e-4))\n",
    "h = model.fit_generator(gen(batch_size), pickle_safe=True, workers=12, \n",
    "                        validation_data=([X_test, y_test, np.ones(n_test)*int(conv_shape[1]-2), label_length_test], \n",
    "                                         np.ones(n_test)), \n",
    "                        # Floor division keeps the step count an integer on Python 3 as well.\n",
    "                        steps_per_epoch=100000//batch_size, epochs=50, \n",
    "                        callbacks=[ReduceLROnPlateau('loss'), \n",
    "                                   ModelCheckpoint('model_gru_best.h5', save_best_only=True)])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "model.save('model.h5')\n",
    "\n",
    "# Plot training and validation loss, leaving out the first `skip` epochs\n",
    "# (presumably so the large initial losses do not dominate the scale).\n",
    "skip = 10\n",
    "train_loss = h.history['loss'][skip:]\n",
    "val_loss = h.history['val_loss'][skip:]\n",
    "# Label the x axis with the real epoch index rather than restarting at 0.\n",
    "epoch_axis = range(skip, skip + len(train_loss))\n",
    "plt.plot(epoch_axis, train_loss, label='loss')\n",
    "plt.plot(epoch_axis, val_loss, label='val_loss')\n",
    "plt.xlabel('epoch')\n",
    "plt.ylabel('CTC loss')\n",
    "plt.legend()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 可视化"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Draw 16 fresh samples under their own name, so the global test set that evaluate()\n",
    "# reads (X_test / y_test) is not overwritten by this cell.\n",
    "[X_vis, _, _, _], _ = next(gen(16))\n",
    "y_pred = base_model.predict(X_vis)\n",
    "# Drop the first two time steps, as the CTC loss does, then decode.\n",
    "shape = y_pred[:,2:,:].shape\n",
    "out = K.get_value(K.ctc_decode(y_pred[:,2:,:], input_length=np.ones(shape[0])*shape[1])[0][0])[:, :n_len]\n",
    "\n",
    "plt.figure(figsize=(16, 8))\n",
    "for i, (img, decoded) in enumerate(zip(X_vis, out)):\n",
    "    plt.subplot(4, 4, i+1)\n",
    "    plt.imshow(img.transpose(1, 0, 2))\n",
    "    # -1 entries are padding in the decoded output.\n",
    "    s = ''.join([characters[x] for x in decoded if x > -1])\n",
    "    try:\n",
    "        # eval() is tolerable here only because `s` is built exclusively from `characters`\n",
    "        # (digits, + - *, parentheses, space); never use it on untrusted text.\n",
    "        plt.title('%s=%d'%(s, eval(s)))\n",
    "    except Exception:\n",
    "        # The decoded string is not a valid expression: show and print it as-is.\n",
    "        plt.title(s)\n",
    "        print(s)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 2",
   "language": "python",
   "name": "python2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
